# %%

%matplotlib inline
import matplotlib.pyplot as plt, numpy as np, pims, pandas as pd, cv2 as cv, scipy.optimize as op, seaborn as sns

def gb_func(img):
    """Return a Canny edge map of the mound surface for 'gb' images.

    Channel 0 of ``img`` is median-blurred (31 px) and binarised with a
    Gaussian adaptive threshold (block size 3001, C=1); Canny then outlines
    the binary mask.

    Parameters
    ----------
    img : array-like
        Colour crop indexed as ``img[row, col, channel]``.

    Returns
    -------
    numpy.ndarray
        Canny edge map of the thresholded channel.
    """
    # Median filter suppresses grain-scale texture before thresholding.
    img = cv.medianBlur(img[:, :, 0], 31)
    # Adaptive threshold over a very large (3001 px) Gaussian neighbourhood.
    img = cv.adaptiveThreshold(img, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C,
                               cv.THRESH_BINARY, 3001, 1)
    # Extract surface of bead mound. Canny is applied exactly once; the
    # previous version also ran it on the resulting edge map, which traces
    # the sides of each 1-px edge line and can double the contour.
    return cv.Canny(img, 100, 100)

def oc_func(img):
    """Return a Canny edge map of the mound surface for 'oc' images.

    Channel 2 of ``img`` is median-blurred (51 px) and binarised with a
    Gaussian adaptive threshold (block size 5001, C=1). The mask is then
    closed twice with a disk of diameter ~100 px (first inverted, then
    inverted back) before Canny extracts its outline.
    """
    size = 100
    half = size / 2
    axis = np.linspace(-half, half, size + 1)
    grid_x, grid_y = np.meshgrid(axis, axis)
    # Disk-shaped structuring element: 1 strictly inside radius `half`.
    disk = (np.sqrt(grid_x**2 + grid_y**2) < half).astype(np.uint8)

    channel = cv.medianBlur(img[:, :, 2], 51)
    mask = cv.adaptiveThreshold(channel, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C,
                                cv.THRESH_BINARY, 5001, 1)
    # Invert-and-close twice: the second pass reverses the first inversion
    # and removes noise above the bed.
    for _ in range(2):
        mask = cv.morphologyEx(cv.bitwise_not(mask), cv.MORPH_CLOSE, disk,
                               iterations=1)
    # Outline of the cleaned mask = surface of the bead mound.
    return cv.Canny(mask, 100, 100)

def gg_func(img):
    """Return a Canny edge map of the mound surface for 'gg' images.

    Channel 0 of ``img`` is median-blurred (51 px) and binarised with a
    Gaussian adaptive threshold (block size 5001, C=1). The binary mask is
    closed twice with a disk of diameter ~100 px (once inverted, once
    re-inverted) and Canny then traces its outline.
    """
    blurred = cv.medianBlur(img[:, :, 0], 51)
    binary = cv.adaptiveThreshold(
        blurred, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY, 5001, 1
    )

    # Disk-shaped structuring element sampled on a (diameter+1)^2 grid.
    diameter = 100
    coords = np.linspace(-diameter / 2, diameter / 2, diameter + 1)
    xx, yy = np.meshgrid(coords, coords)
    disk = np.uint8(np.sqrt(xx**2 + yy**2) < diameter / 2)

    # First closing operates on the inverted mask.
    closed = cv.morphologyEx(cv.bitwise_not(binary), cv.MORPH_CLOSE, disk,
                             iterations=1)
    # Invert back and close again to get rid of noise above the bed.
    closed = cv.morphologyEx(cv.bitwise_not(closed), cv.MORPH_CLOSE, disk,
                             iterations=1)
    # The outline of the cleaned mask traces the surface of the bead mound.
    return cv.Canny(closed, 100, 100)

def ng_func(img):
    """Return a Canny edge map of the mound surface for 'ng' (painted) images.

    Each of the first three channels of ``img`` is median-blurred (11 px)
    and binarised with a Gaussian adaptive threshold (block size 501, C=11).
    The three masks are merged with a pixel-wise minimum, cleaned with two
    closings by a disk of diameter ~100 px, and outlined with Canny.

    Parameters
    ----------
    img : array-like
        Colour crop indexed as ``img[row, col, channel]`` with at least three
        channels.

    Returns
    -------
    numpy.ndarray
        Canny edge map of the cleaned mask.
    """
    masks = [
        cv.adaptiveThreshold(cv.medianBlur(np.asarray(img[:, :, c]), 11), 255,
                             cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY,
                             501, 11)
        for c in range(3)
    ]
    # Pixel-wise minimum over all three masks. The previous
    # np.minimum(img1, img2, img3) passed img3 as the ufunc's `out` argument,
    # so it returned min(img1, img2) and the channel-2 mask was ignored.
    img = np.minimum.reduce(masks)
    radius = 100
    X, Y = np.meshgrid(np.linspace(-radius/2, radius/2, radius+1), np.linspace(-radius/2, radius/2, radius+1))
    # Disk structuring element (1 strictly inside radius/2).
    kernel = np.where(np.sqrt(X**2 + Y**2) < radius/2, 1, 0).astype(np.uint8)
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # reverse closing to get rid of noise above bed
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # Extract surface of bead mound
    return cv.Canny(img, 100, 100)

def ngup_func(img):
    """Return a Canny edge map of the mound surface for 'ngup' (unpainted) images.

    Each of the first three channels of ``img`` is median-blurred (11 px)
    and binarised with a Gaussian adaptive threshold (block size 501, C=11).
    The three masks are merged with a pixel-wise minimum, cleaned with two
    closings by a disk of diameter ~100 px, and outlined with Canny.

    Parameters
    ----------
    img : array-like
        Colour crop indexed as ``img[row, col, channel]`` with at least three
        channels.

    Returns
    -------
    numpy.ndarray
        Canny edge map of the cleaned mask.
    """
    masks = [
        cv.adaptiveThreshold(cv.medianBlur(np.asarray(img[:, :, c]), 11), 255,
                             cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY,
                             501, 11)
        for c in range(3)
    ]
    # Pixel-wise minimum over all three masks. The previous
    # np.minimum(img1, img2, img3) passed img3 as the ufunc's `out` argument,
    # so it returned min(img1, img2) and the channel-2 mask was ignored.
    img = np.minimum.reduce(masks)
    radius = 100
    X, Y = np.meshgrid(np.linspace(-radius/2, radius/2, radius+1), np.linspace(-radius/2, radius/2, radius+1))
    # Disk structuring element (1 strictly inside radius/2).
    kernel = np.where(np.sqrt(X**2 + Y**2) < radius/2, 1, 0).astype(np.uint8)
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # reverse closing to get rid of noise above bed
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # Extract surface of bead mound
    return cv.Canny(img, 100, 100)

def ls_func(img):
    """Return a Canny edge map of the mound surface for 'ls' images.

    Each of the first three channels of ``img`` is median-blurred (11 px)
    and binarised with a Gaussian adaptive threshold (block size 501, C=11).
    The three masks are merged with a pixel-wise minimum, cleaned with two
    closings by a disk of diameter ~100 px, and outlined with Canny.

    Parameters
    ----------
    img : array-like
        Colour crop indexed as ``img[row, col, channel]`` with at least three
        channels.

    Returns
    -------
    numpy.ndarray
        Canny edge map of the cleaned mask.
    """
    masks = [
        cv.adaptiveThreshold(cv.medianBlur(np.asarray(img[:, :, c]), 11), 255,
                             cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY,
                             501, 11)
        for c in range(3)
    ]
    # Pixel-wise minimum over all three masks. The previous
    # np.minimum(img1, img2, img3) passed img3 as the ufunc's `out` argument,
    # so it returned min(img1, img2) and the channel-2 mask was ignored.
    img = np.minimum.reduce(masks)
    radius = 100
    X, Y = np.meshgrid(np.linspace(-radius/2, radius/2, radius+1), np.linspace(-radius/2, radius/2, radius+1))
    # Disk structuring element (1 strictly inside radius/2).
    kernel = np.where(np.sqrt(X**2 + Y**2) < radius/2, 1, 0).astype(np.uint8)
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # reverse closing to get rid of noise above bed
    img = cv.morphologyEx(cv.bitwise_not(img), cv.MORPH_CLOSE, kernel, iterations=1)
    # Extract surface of bead mound
    return cv.Canny(img, 100, 100)

# Grain types processed by the analysis cells ('oc1' has settings below but
# is not listed here, so it is skipped).
grains = ['gg', 'oc', 'ng', 'ngup', 'gb', 'ls']

# Glob pattern of each grain type's image sequence.
fns = dict(
    ng='ng/painted/*.jpg',
    ngup='ng/unpainted/*.JPG',
    gg='gg/*.JPG',
    oc1='oc/oc1/*.JPG',
    oc='oc/oc2/*.JPG',
    gb='gb/*.JPG',
    ls='ls/*.JPG',
)

# Crop window (pixel bounds) containing the plumb line; only the bottom bound
# y2 differs between grain types.
line_dims = {
    name: {'x1': 1500, 'x2': -1500, 'y1': 0, 'y2': bottom}
    for name, bottom in [('gb', 1100), ('ng', 900), ('ngup', 900), ('gg', 900),
                         ('oc1', 900), ('oc', 900), ('ls', 800)]
}

# Crop window containing the bead mound: symmetric horizontal margin, then
# top and bottom row bounds.
mound_dims = {
    name: {'x1': margin, 'x2': -margin, 'y1': top, 'y2': bottom}
    for name, margin, top, bottom in [
        ('gb', 1250, 1100, 1500),
        ('ng', 1100, 900, 1450),
        ('ngup', 1100, 1000, 1600),
        ('gg', 1000, 1100, 1500),
        ('oc1', 900, 900, 1400),
        ('oc', 1100, 1100, 1550),
        ('ls', 1000, 900, 1700),
    ]
}

# [left, right] pixel margins around the plumb-line centre that are excluded
# from the left/right flank fits.
offset = {name: [200, 200] if name == 'gg' else [120, 120]
          for name in ['gb', 'ng', 'ngup', 'gg', 'oc1', 'oc', 'ls']}

# Edge-extraction function applied to each grain type's mound crop.
func_calls = dict(gb=gb_func, ng=ng_func, ngup=ngup_func, gg=gg_func,
                  oc1=oc_func, oc=oc_func, ls=ls_func)

#%%
# Locate the plumb line in every image and record its lean angle and its
# mean horizontal position.

def func(x, b, m):
    """Straight line y = m*x + b, in the parameter order used by curve_fit."""
    return m * x + b


# Per-grain arrays, one entry per image.
lean = {}        # angle (deg) of the fitted y-vs-x line; +/-90 means vertical
centerline = {}  # mean x of the plumb line, in full-image pixel coordinates

for key in grains:
    print(fns[key])
    imgs = pims.ImageSequence(fns[key])
    n_imgs = len(imgs)  # public API instead of the private imgs._count
    dims = line_dims[key]
    lean[key] = np.zeros(n_imgs)
    centerline[key] = np.zeros(n_imgs)
    for j in range(n_imgs):
        # New figure per image: plt.figure(j) reused the same figure numbers
        # for every grain type, overlaying their plots.
        plt.figure()
        # Load the frame once and crop it to the plumb-line window.
        crop = np.asarray(imgs[j][dims['y1']:dims['y2'], dims['x1']:dims['x2']])

        # Channel-1 / channel-0 ratio, capped at 1 and scaled to 8 bit.
        # x/0 -> inf -> capped to 1; 0/0 -> NaN, mapped to 0 explicitly
        # rather than left to an undefined NaN -> uint8 cast.
        with np.errstate(divide='ignore', invalid='ignore'):
            ratio = crop[:, :, 1] / crop[:, :, 0]
        ratio = np.minimum(np.nan_to_num(ratio, nan=0.0), 1.0)
        img = (ratio * 255).astype(np.uint8)

        # Binarise and trace the outline of the plumb line.
        _, img = cv.threshold(img, 100, 255, cv.THRESH_BINARY)
        img = cv.Canny(img, 100, 100)

        # Mean x of the edge pixels in every row that contains any.
        ys, xs = np.where(img > 0)
        row_mean = pd.Series(xs).groupby(ys).mean()
        y_out = row_mean.index.to_numpy()
        # Smooth with a centred rolling window as tall as the crop's bottom
        # bound (y2); min_periods=1 keeps the ends defined. `int` replaces
        # `np.int`, which NumPy 1.24 removed.
        x_out = pd.Series(row_mean.to_numpy()).rolling(
            window=int(dims['y2']), center=True, min_periods=1).mean()

        # Straight-line fit y = m*x + b through the smoothed points.
        popt, _ = op.curve_fit(func, x_out, y_out)

        plt.plot(xs, ys, '.')
        plt.plot(x_out, y_out)
        plt.plot(x_out, func(x_out, *popt), 'r-')

        lean[key][j] = np.degrees(np.arctan(popt[1]))
        # Shift from crop coordinates back to full-image x coordinates.
        centerline[key][j] = x_out.mean() + dims['x1']

print(lean)
print(centerline)

# %%
# Fit straight lines to the left and right flanks of each mound and convert
# their slopes to angles of repose, corrected by the plumb-line tilt.

def func(x, b, m):
    """Straight line y = m*x + b, in the parameter order used by curve_fit."""
    return m * x + b


AoR_l = {}
AoR_r = {}
AoRs = {}
for key in grains:
    imgs = pims.ImageSequence(fns[key])
    n_imgs = len(imgs)  # public API instead of the private imgs._count
    dims = mound_dims[key]
    AoR_l[key] = np.zeros(n_imgs)
    AoR_r[key] = np.zeros(n_imgs)
    for j in range(n_imgs):
        # Read and crop the frame once; it is both displayed and processed.
        crop = imgs[j][dims['y1']:dims['y2'], dims['x1']:dims['x2']]
        # New figure per image (count*10 + j collided beyond 10 images).
        plt.figure()
        plt.imshow(np.asarray(crop))
        img = func_calls[key](crop)

        # Mound surface = topmost (smallest row index) edge pixel per column.
        ys, xs = np.where(img > 0)
        surface = pd.Series(ys).groupby(xs).min()
        x_out = surface.index.to_numpy()
        y_out = surface.to_numpy()

        # Plumb-line centre in crop coordinates; columns within offset[key]
        # pixels of it are excluded from both flank fits.
        center = centerline[key][j] - dims['x1']
        # Deviation of the plumb line from vertical: lean is the angle of a
        # y-vs-x fit, so a vertical line has |lean| = 90. Same value as the
        # former `90 - lean if lean > 0 else lean + 90`.
        tilt = 90 - abs(lean[key][j])
        # NOTE(review): tilt is unsigned, so it is always added on the left
        # and subtracted on the right whichever way the line leans -- confirm.

        # left side
        left = x_out < center - offset[key][0]
        x_l, y_l = x_out[left], y_out[left]
        popt, _ = op.curve_fit(func, x_l, y_l)
        AoR_l[key][j] = np.abs(np.degrees(np.arctan(popt[1]))) + tilt

        plt.plot(x_l, y_l, 'r-')
        plt.plot(x_l, func(x_l, *popt), 'y-')

        # right side
        right = x_out > center + offset[key][1]
        x_r, y_r = x_out[right], y_out[right]
        popt, _ = op.curve_fit(func, x_r, y_r)
        AoR_r[key][j] = np.abs(np.degrees(np.arctan(popt[1]))) - tilt

        plt.plot(x_r, y_r, 'b-')
        plt.plot(x_r, func(x_r, *popt), 'y-')

    AoRs[key] = np.hstack([AoR_l[key], AoR_r[key]])

# Pad every grain's angles with NaN to one common length so they can be
# stacked into a DataFrame. The former fixed-size checks padded only up to 12
# values (6 images per grain); a longer series made pd.DataFrame fail.
width = max([12] + [angles.size for angles in AoRs.values()])
for key in AoRs:
    AoRs[key] = np.concatenate(
        [AoRs[key], np.full(width - AoRs[key].size, np.nan)])

print(AoRs)

# %%

# Per-grain mean and spread of every measured angle (left and right flanks
# pooled); the NaN padding from the previous cell is skipped.
df = pd.DataFrame(AoRs).T  # one row per grain type
# Each grain is exactly one row, so row-wise NaN-skipping statistics equal
# the former reset_index().groupby('index').apply(np.nanmean / np.nanstd).
# This avoids depending on how pandas passes the string grouping column to
# apply(). sort_index keeps groupby's alphabetical order, which is also the
# x-axis order of the plot.
xf_mean = df.mean(axis=1).sort_index()
xf_std = df.std(axis=1, ddof=0).sort_index()  # ddof=0 matches np.nanstd
plt.errorbar(xf_mean.index, xf_mean.values, yerr=xf_std.values)
# %%

# Show the per-grain summary table. This cell previously evaluated a bare
# `xf`, a name defined nowhere in this script, which raised NameError.
pd.DataFrame({'mean': xf_mean, 'std': xf_std})

# %%
